import numpy as np
import scipy
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
import os
from utils import df1, df2, LMC
import argparse


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace; ``potential`` is ``'log-sum-exp'`` or
        ``'cosine'``.
    """
    parser = argparse.ArgumentParser()
    # Only these two values select a gradient below. Rejecting anything else
    # (or a missing flag) here avoids running with 'grad potential' left as
    # None and with a "test_function_None/..." path.
    parser.add_argument(
        "--potential",
        required=True,
        choices=('log-sum-exp', 'cosine'),
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Run ``LMC`` once per step size and save each run's result.

    For each of the 10 step sizes in ``np.linspace(0.1, 1, 10)`` the config
    dict is handed to ``LMC`` and whatever it returns is written to
    ``test_function_<potential>/h/h=<h>.npy``. The wall-clock time of each
    run is printed.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means the
            process command line.
    """
    np.random.seed(0)

    args = parse_args(argv)

    # NOTE(review): the meaning of each key is defined by utils.LMC, which is
    # not visible here; values are unchanged from the original script.
    config = {
        'step size': 0,
        'num samples': int(1e5),
        'grad potential': None,
        'dimension': 10,
        'initial condition': 1,
        'T': 20,
        'stats function': None,
    }

    if args.potential == 'log-sum-exp':
        config['grad potential'] = df1
    elif args.potential == 'cosine':
        config['grad potential'] = lambda x: df2(x, config['dimension'])

    # The reference mean does not depend on the step size, so load it once
    # instead of re-reading the benchmark file on every iteration.
    true_mean = np.load(f"test_function_{args.potential}/benchmark-accurate/d={config['dimension']}.npy").mean(axis=0)
    # Statistic: Euclidean distance between the axis-0 mean of a sample array
    # and the reference mean.
    config['stats function'] = lambda sample: np.linalg.norm( sample.mean(axis=0) - true_mean )

    # Create the output directory up front so a missing folder is not
    # discovered by np.save only after the first LMC run has completed.
    os.makedirs(f'test_function_{args.potential}/h', exist_ok=True)

    for h in np.linspace(0.1, 1, 10):
        config['step size'] = h

        start = time.time()
        hist = LMC(config)
        end = time.time()
        print(f'step size {h:.5f} finished: {end - start:.3f}s elapsed.')

        # h is interpolated without a format spec, so the file name carries
        # the full float repr. Kept as-is because other scripts may look the
        # files up by exactly these names.
        np.save(f'test_function_{args.potential}/h/h={h}.npy', hist)


if __name__ == '__main__':
    main()

        